import torch
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle


def contour(x, y, Z, trajectories, args):
    """Draw a labelled contour plot of ``Z``, optionally overlay paths, and save it as a PNG.

    Args:
        x: Grid coordinates along the horizontal axis.
        y: Grid coordinates along the vertical axis.
        Z: Values on the grid. Must have the same shape as the arrays returned by
            ``np.meshgrid(x, y)`` (default 'xy' indexing), i.e. ``(len(y), len(x))``.
        trajectories: Mapping of legend label -> path, or a falsy value to draw no
            paths. Each path is converted with ``np.array`` and its entries 0 and 1
            are plotted as the x and y coordinates respectively.
        args: Config object. ``vmin``, ``vmax`` and ``vlevel`` define the contour
            levels ``np.arange(vmin, vmax, vlevel)`` (``vmax`` itself is excluded).
            ``task_name``, ``id``, ``xmin``, ``xmax``, ``xnum``, ``ymin``, ``ymax`` and
            ``ynum`` are only used to build the output file name.

    Side effects:
        Writes ``../image/{args.task_name}/<...>.png``, relative to the current
        working directory. The directory is created if it is missing. The figure
        is always closed afterwards, so repeated calls do not accumulate open figures.
    """
    import os

    X, Y = np.meshgrid(x, y)
    fig = plt.figure(figsize=(8, 6))
    try:
        contour_plot = plt.contour(X, Y, Z, cmap='summer', levels=np.arange(args.vmin, args.vmax, args.vlevel))
        plt.clabel(contour_plot, inline=1, fontsize=8)
        if trajectories:
            line_styles = ['-', '--', '-.', ':']
            colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
            # Styles (4) and colors (7) cycle independently, so consecutive paths
            # get distinct style/color combinations.
            style_cycle = cycle(line_styles)
            color_cycle = cycle(colors)
            for name, path in trajectories.items():
                path = np.array(path)
                # NOTE(review): assumes path is laid out as (2, N): row 0 = xs, row 1 = ys.
                # If callers pass a list of (x, y) points, i.e. shape (N, 2), this would
                # plot the first two points instead. TODO confirm with callers.
                plt.plot(path[0], path[1], linestyle=next(style_cycle), color=next(color_cycle), label=f'{name}', alpha=0.9)
            # Only the trajectories carry labels. Calling legend() with none drawn
            # makes matplotlib warn "No artists with labels found".
            plt.legend()
        plt.title('line')
        plt.xlabel('X-axis')
        plt.ylabel('Y-axis')
        ax = plt.gca()  # current axes
        # An aspect of ptp(x)/ptp(y) makes the x and y data ranges span equal display
        # lengths, giving a roughly square plotting area regardless of the units.
        ax.set_aspect(1.0 * (np.ptp(x) / np.ptp(y)))
        out_dir = f'../image/{args.task_name}'
        # savefig does not create missing directories, so create the target first.
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(f'{out_dir}/{args.id}_contour_{args.xmin}_{args.xmax}_{args.xnum}_{args.ymin}_{args.ymax}_{args.ynum}_{args.vmin}_{args.vmax}_{args.vlevel}.png')
    finally:
        # Release the figure even if plotting/saving raised. Otherwise pyplot keeps a
        # reference to it forever.
        plt.close(fig)